import torch
import re
from PIL import Image
from abc import abstractproperty
import sys
import os.path as osp
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
import copy


class Kosmos2(BaseModel):
    """Evaluation wrapper for Microsoft's Kosmos-2 grounded vision-language model.

    Loads ``Kosmos2ForConditionalGeneration`` in fp16 on CUDA together with its
    processor, and implements the generate / custom-prompt hooks expected by
    the evaluation framework.
    """

    INSTALL_REQ = True   # requires a specific transformers version (see __init__)
    INTERLEAVE = True    # accepts interleaved image/text message segments

    def __init__(self, model_path="microsoft/kosmos-2-patch14-224", **kwargs):
        """Load the Kosmos-2 model and processor.

        Args:
            model_path: HF hub id (``org/name``) or an existing local
                checkpoint directory.
            **kwargs: extra generation options merged over the defaults
                (``max_new_tokens=512, use_cache=True``); caller values win.

        Raises:
            Exception: the original import error, re-raised after logging,
                when transformers is missing or incompatible.
        """
        try:
            from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
        except Exception as e:
            logging.critical(
                "Please install Transformers version 4.45.1 by running: pip install transformers==4.45.1"
            )
            raise e

        # Accept either a local path or a two-component "org/name" hub id.
        assert osp.exists(model_path) or splitlen(model_path) == 2

        self.model = Kosmos2ForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.float16
        ).to(torch.device("cuda"))
        self.processor = AutoProcessor.from_pretrained(model_path)

        default_kwargs = dict(max_new_tokens=512, use_cache=True)
        default_kwargs.update(kwargs)  # user-supplied kwargs override defaults
        self.kwargs = default_kwargs
        warnings.warn(
            f"Following kwargs received: {self.kwargs}, will use as generation config. "
        )
        torch.cuda.empty_cache()

    def generate_inner(self, message, dataset=None):
        """Generate an answer for a message of image/text segments.

        Every text segment is wrapped as ``Question: ... Answer:`` after the
        Kosmos-2 ``<grounding>`` task token.  NOTE: although all image paths
        in the message are collected and opened, only the FIRST image is fed
        to the processor — Kosmos-2 is driven here as a single-image model.

        Args:
            message: list of dicts with ``type`` ("image" | "text") and
                ``value`` (file path or text).
            dataset: unused; kept for interface compatibility.

        Returns:
            str: generated text with the echoed ``Question:``/``Answer:``
            scaffolding stripped.
        """
        TASK_TOKEN = "<grounding> "
        QUESTION_TOKEN = "Question: "  # fixed internal typo (was QEUSTION_TOKEN)
        ANSWER_TOKEN = "Answer: "

        image_paths = []
        prompt = TASK_TOKEN
        for seg in message:
            if seg["type"] == "image":
                image_paths.append(seg["value"])
            elif seg["type"] == "text":
                prompt += QUESTION_TOKEN + seg["value"] + ANSWER_TOKEN

        images = [Image.open(p) for p in image_paths]
        # Only the first image is used; raises IndexError if the message
        # contains no image segment.
        inputs = self.processor(text=prompt, images=images[0], return_tensors="pt").to(
            torch.device("cuda")
        )

        generated_ids = self.model.generate(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            image_embeds=None,
            image_embeds_position_mask=inputs["image_embeds_position_mask"],
            **self.kwargs,
        )

        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        # post_process_generation returns (cleaned_text, entities); keep the text.
        processed_text = self.processor.post_process_generation(
            generated_text, cleanup_and_extract=True
        )[0]
        # Drop any echoed "Question: ... Answer:" prefix (or a dangling
        # "Question: ..." with no answer marker) from the decoded output.
        cleaned_answer = re.sub(
            r"(Question:.*?Answer:|Question:.*)", "", processed_text
        ).strip()
        return cleaned_answer

    def use_custom_prompt(self, dataset):
        """Return True when build_prompt should handle this dataset.

        Custom prompting is used for MCQ-type datasets and MMVet, but never
        for MMMU (which is excluded even though it is MCQ-typed).
        """
        assert dataset is not None
        if listinstr(["MMMU"], dataset):
            return False
        if DATASET_TYPE(dataset) == "MCQ" or dataset == "MMVet":
            return True
        return False

    def build_prompt(self, line, dataset=None):
        """Build the interleaved message for one dataset record.

        Args:
            line: a record (dict-like / pandas row) with at least a
                ``question`` field; MCQ records may also carry option
                columns ``A``-``Z`` and an optional ``hint``.
            dataset: dataset name; must satisfy ``use_custom_prompt``.

        Returns:
            list[dict]: a text segment followed by one image segment per
            dumped image path.

        Raises:
            NotImplementedError: for dataset types this builder does not
                support.
        """
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)
        question = line["question"]
        if dataset == "MMVet":
            prompt = question + "\nAnswer the question directly. "
        elif DATASET_TYPE(dataset) == "MCQ":
            # Collect whichever option columns (A, B, C, ...) are present.
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = ""
            for key, item in options.items():
                options_prompt += f"{key}. {item}\n"

            hint = (
                line["hint"] if ("hint" in line and not pd.isna(line["hint"])) else None
            )
            prompt = f"Hint: {hint}\n" if hint is not None else ""
            prompt += f"{question}\n"
            prompt += (
                f"{options_prompt}\nAnswer with the option’s letter from the given choices directly. "
                if len(options)
                else "Answer the question directly. "
            )
        else:
            raise NotImplementedError

        message = [dict(type="text", value=prompt)]
        message.extend([dict(type="image", value=s) for s in tgt_path])
        return message
